
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Model parameters for the single-layer SO2 box model, collected in one dict
# that is passed to aqueous_box_fn.
p = {}
p['pres'] = 0.3e5  # pressure of layer in Pa
# Factor converting a mol/mol mixing ratio to ppm: 1 ppm = 1e-6 mol/mol.
# (Previously written as 10e6 == 1e7, which made the "10 ppm" source
# mixing ratio below only 1 ppm.)
p['ppm'] = 1e6

# S(IV) acid-dissociation constants and Henry's law constant, given at Tref
p['Ks1_ref'] = 1.3e-2  # mol/L at 298 K
p['Ks2_ref'] = 6.6e-8  # mol/L at 298 K
p['Tref'] = 298.0  # reference temperature for the constants above, K
p['R_kcal'] = 1.9872e-3  # gas constant, kcal/mol/K
# mol/kg/Pa = mol/L/Pa = M/Pa.
# NOTE(review): 1.23 is divided by 1e5 Pa (1 bar); if the source value is in
# M/atm the divisor should be 101325 Pa -- confirm.
p['H_SO2_ref'] = 1.23/1e5
p['Tlay'] = p['Tref']  # temperature of layer in K
p['Rstar'] = 8.314462  # ideal gas constant J/K/mol
p['n_air'] = p['pres']/(p['Rstar']*p['Tlay'])  # local air molar density [mol/m3]

# Source SO2 mixing ratio: 10 ppm, expressed as a mol/mol fraction
f_SO2_0 = 10/p['ppm']
p['Kedd'] = 1e-4*1e4  # eddy diffusion coefficient, m2/s
p['Hsca'] = 20e3  # layer height in m
p['Lw'] = 3e-3  # liquid water content L/m3

# Timescales in seconds (pH is passed to the model separately and held fixed)
p['tau_diss'] = 50.0  # relaxation time toward dissolution equilibrium
p['tau_eddy'] = p['Hsca']**2/(2*p['Kedd'])  # eddy exchange over height Hsca
# Rain-out timescale; an earlier version used 86400 s (1 day).
# NOTE(review): flagged as "kind of long?" -- the derivation of this
# expression is not documented here; confirm the intended inputs.
p['tau_rain'] = 10.0e5*3e-3/(1340/2260000.0)
p['tau_ox'] = 86400.0  # S(IV) -> S(VI) oxidation time (1 day)


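# Porting notes: at this point the original MATLAB script defined the source
# SO2 density (n_SO2_0), set disp_res = 1, and then called aqueous_box_fn.
# In this Python port those steps happen inside the parameter-sweep loop
# further down.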



def aqueous_box_fn(p, n_SO2_0, pH_fn, disp_res, *, nt=1000,
                   t_end=300*86400.0, rtol=1e-8, atol=None):
    """Integrate the SO2 / S(IV) / S(VI) box model; return final mixing ratios.

    The state vector X holds molar densities per volume of air (mol/m3):
      X[0]  gas-phase SO2
      X[1]  dissolved S(IV)
      X[2]  S(VI)

    Processes:
      * relaxation of X[0] toward the source density n_SO2_0 (p['tau_eddy']),
      * relaxation of the gas/aqueous split of total S(IV) toward the
        Henry's-law + acid-dissociation equilibrium at fixed pH and
        temperature p['Tlay'] (p['tau_diss']),
      * first-order oxidation X[1] -> X[2] (p['tau_ox']),
      * first-order removal of X[1] and X[2] (p['tau_rain']).

    Parameters
    ----------
    p : dict
        Model parameters: Ks1_ref, Ks2_ref, Tref, R_kcal, H_SO2_ref, Tlay,
        Rstar, Lw, n_air and the timescales tau_eddy, tau_diss, tau_ox,
        tau_rain (seconds).
    n_SO2_0 : float
        Source SO2 molar density (mol/m3); also the initial value of X[0].
    pH_fn : float
        Aqueous-phase pH, held fixed (a number, despite the name).
    disp_res : bool-like
        If truthy, print a progress line for every output timestep.
    nt : int, optional
        Number of points (including t=0) on the uniform output time grid.
    t_end : float, optional
        Integration end time in seconds (default 300 days).
    rtol : float, optional
        Relative tolerance passed to ``odeint``.
    atol : float, optional
        Absolute tolerance passed to ``odeint``. Defaults to
        ``1e-10 * abs(n_SO2_0)``; if n_SO2_0 is zero, odeint's own default
        is used.

    Returns
    -------
    numpy.ndarray
        Final state X(t_end) divided by p['n_air'] (mol/mol), shape (3,).

    Raises
    ------
    ValueError
        If ``nt < 1``.
    """
    n_times = int(nt)
    if n_times < 1:
        raise ValueError("nt must be >= 1")

    # pH and temperature are constant during the run, so the equilibrium
    # constants (van 't Hoff temperature correction from Tref) and the
    # equilibrium partition factor are computed once.
    T = p['Tlay']
    inv_dT = 1/p['Tref'] - 1/T
    Ks1 = p['Ks1_ref']*np.exp(-1960.0*inv_dT)
    Ks2 = p['Ks2_ref']*np.exp(-1500.0*inv_dT)
    H_SO2 = p['H_SO2_ref']*np.exp(-(6.25/p['R_kcal'])*inv_dT)
    aH = 10**(-pH_fn)
    # Total dissolved S(IV) per dissolved SO2: 1 + [HSO3-]/[SO2] + [SO3--]/[SO2]
    fn_S_IV = 1 + Ks1 / aH + Ks1 * Ks2 / aH**2
    # Equilibrium gas-phase SO2 is (total S(IV)) / eq_denom (dimensionless).
    eq_denom = 1 + p['Lw'] * H_SO2 * fn_S_IV * p['Rstar'] * T

    tau_eddy = p['tau_eddy']
    tau_diss = p['tau_diss']
    tau_ox = p['tau_ox']
    tau_rain = p['tau_rain']

    def dXdt(X, t):
        n_SO2, n_S_IV, n_S_VI = X
        eddy = (n_SO2_0 - n_SO2)/tau_eddy
        dissolve = ((n_S_IV + n_SO2)/eq_denom - n_SO2)/tau_diss
        oxid = n_S_IV/tau_ox
        return [eddy + dissolve,
                -dissolve - oxid - n_S_IV/tau_rain,
                oxid - n_S_VI/tau_rain]

    if atol is None and n_SO2_0:
        # odeint's default atol (~1.5e-8) can exceed the dissolved-sulfur
        # state values here, which would leave them effectively unresolved;
        # scale the absolute tolerance to the problem instead.
        atol = 1e-10*abs(n_SO2_0)

    t_a = np.linspace(0, t_end, n_times)
    X0 = [n_SO2_0, 1e-8, 1e-6]  # initial [SO2, S(IV), S(VI)] in mol/m3
    X_a = np.zeros((len(t_a), len(X0)))
    X_a[0, :] = X0  # initial condition belongs in row 0

    X = X0
    for it in range(1, len(t_a)):
        if disp_res:
            print('starting timestep ', it)
        # Integrate one output interval at a time; hmax caps the internal
        # step so short-lived transients are not stepped over.
        X = odeint(dXdt, X, t_a[it - 1:it + 1], hmax=1e4,
                   rtol=rtol, atol=atol)[-1, :]
        X_a[it, :] = X

    return X_a[-1, :]/p['n_air']

# Parameter sweep over oxidation timescale, eddy diffusivity and fixed pH.
# result[ik, iox, iphs, :] holds the final [SO2, S(IV), S(VI)] mixing ratios
# (mol/mol) returned by aqueous_box_fn for edd[ik], ox[iox], phs[iphs].
ox = (10**np.arange(-1, 2.02, 0.3))*86400.0  # tau_ox: 10^-1 .. 10^2 days, in s
edd = (10**np.arange(3, 8.1, 0.5))*1e-4      # eddy diffusivity: 0.1 .. 1e4 m2/s
phs = np.array([1.5, 2.5, 4.0, 5.0, 6.0])   # fixed pH values
result = np.zeros((len(edd), len(ox), len(phs), 3))

# Loop invariants: dissolution timescale for the whole sweep, and the
# source SO2 molar density [mol/m3].
p['tau_diss'] = 1.0
n_SO2_0 = f_SO2_0*p['n_air']

for iox, tau_ox in enumerate(ox):
    for ik, kedd in enumerate(edd):
        for iphs, pH in enumerate(phs):
            p['tau_ox'] = tau_ox
            # Keep Kedd consistent with the eddy timescale actually used.
            p['Kedd'] = kedd
            p['tau_eddy'] = p['Hsca']**2/(2*kedd)
            result[ik, iox, iphs, :] = aqueous_box_fn(p, n_SO2_0, pH, 1)

dat = {'result': result, 'phs': phs, 'ox': ox, 'edd': edd}
import pickle
with open('resultgridphfix-10ppm.pickle', 'wb') as handle:
    # pickle.dump returns None; do not rebind dat to its result.
    pickle.dump(dat, handle, protocol=pickle.HIGHEST_PROTOCOL)

# NOTE(review): the string literal below is a bare expression statement and
# is never used at runtime. It preserves notes from the MATLAB original: the
# steady-state relation between final SO2 and final S(VI) (x1_ana), the
# MATLAB call site of aqueous_box_fn, and an earlier copy of the same
# equilibrium formulas used in aqueous_box_fn. Consider converting it to
# comments or deleting it.
'''
#X_f = X_a(end,:);
#f_f = X_f/p.n_air;
#x1_ana = n_SO2_0 - ((p.tau_rain + p.tau_ox)/p.tau_rain)*(p.tau_eddy/p.tau_rain)*X_a(end,3)#

f_f(iR,:) = aqueous_box_fn(p,n_SO2_0,pH_fn,disp_res);




    Ks1 = lambda T: p['Ks1_ref'] * np.exp(-1960 * (1 / p['Tref'] - 1 / T))
    Ks2 = lambda T: p['Ks2_ref'] * np.exp(-1500 * (1 / p['Tref'] - 1 / T))
    H_SO2 = lambda T: p['H_SO2_ref'] * np.exp(-(6.25 / p['R_kcal']) * (1 / p['Tref'] - 1 / T))

    aH = lambda pH: 10**(-pH)
    fn_S_IV = lambda pH, T: 1 + Ks1(T) / aH(pH) + Ks1(T) * Ks2(T) / aH(pH)**2
    n_SO2_eq = lambda pH, T, n_S: n_S / (1 + p['Lw'] * H_SO2(T) * fn_S_IV(pH, T) * p['Rstar'] * T)
'''
